import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats.mstats import winsorize

# Load the cleaned Compustat panel and attach a column of ones; summing
# 'const' within a group later yields the number of rows in that group.
compdat = pd.read_stata('data/Compustat/Compustat_clean_with_lags.dta').assign(const=1)

def groupby_cov_fn(x, var1, var2, wt):
	"""Return the weighted (population) covariance of two columns of ``x``.

	Rows with a missing value in any of ``wt``, ``var1`` or ``var2`` are
	dropped first. Weighted means are taken as mean(w * v) / mean(w).

	Args:
		x: DataFrame (typically one group from ``DataFrame.groupby``).
		var1, var2: names of the two columns to covary.
		wt: name of the weight column.

	Returns:
		The weighted covariance as a float. NaN when the mean weight of the
		complete-case rows is not positive, including when no rows remain.
		This matches the guard in ``groupby_var_fn``.
	"""
	y = x.dropna(subset=[wt, var1, var2])
	norm = y[wt].mean()
	# `not norm > 0` also catches a NaN mean (empty frame). Without this guard
	# a zero mean divides by zero, and a negative mean yields a meaningless value.
	if not norm > 0:
		return np.nan
	w = y[wt]
	mean1 = (w * y[var1]).mean() / norm
	mean2 = (w * y[var2]).mean() / norm
	return (w * y[var1] * y[var2]).mean() / norm - mean1 * mean2

def groupby_var_fn(x, var, wt):
	"""Return the weighted (population) variance of column ``var`` in ``x``.

	Rows missing either ``wt`` or ``var`` are discarded. The weighted mean of a
	quantity q is mean(w * q) / mean(w). Returns NaN when that mean weight is
	not positive.
	"""
	complete = x.dropna(subset=[wt, var])
	weights = complete[wt]
	mean_weight = weights.mean()
	if mean_weight <= 0:
		return np.nan
	values = complete[var]
	second_moment = (weights * values * values).mean() / mean_weight
	first_moment = (weights * values).mean() / mean_weight
	return second_moment - first_moment ** 2

def clip_series(series, quantile_low, quantile_high):
	"""Winsorize ``series`` at its own ``quantile_low`` / ``quantile_high`` quantiles.

	Values below the lower quantile are raised to it and values above the upper
	quantile are lowered to it. The quantiles are computed over the whole series.
	"""
	lower_bound = series.quantile(quantile_low)
	upper_bound = series.quantile(quantile_high)
	return series.clip(lower=lower_bound, upper=upper_bound)

def clip_with_groups(df, series_name, groupby, quantile_low, quantile_high):
	"""Winsorize ``df[series_name]`` at quantiles computed within each group.

	Args:
		df: source DataFrame.
		series_name: column to winsorize.
		groupby: column name(s) passed to ``df.groupby``.
		quantile_low, quantile_high: per-group quantile levels used as bounds.

	Returns:
		(clipped, flags): ``clipped`` is the winsorized column. ``flags`` is 1
		where a value was at or beyond a bound, 0 otherwise, and NaN where the
		input was NaN. Values exactly equal to a bound are flagged too.
	"""
	grouped = df.groupby(groupby)[series_name]
	lower_bound = grouped.transform(lambda col: col.quantile(quantile_low))
	upper_bound = grouped.transform(lambda col: col.quantile(quantile_high))
	original = df[series_name]
	clipped = original.copy()
	flags = original.copy()
	flags[original.notna()] = 0
	# Lower tail first (bound Series aligns on the index).
	below = clipped <= lower_bound
	flags.loc[below] = 1
	clipped.loc[below] = lower_bound
	# The upper-tail test deliberately runs on the already lower-clipped values.
	above = clipped >= upper_bound
	flags.loc[above] = 1
	clipped.loc[above] = upper_bound
	return clipped, flags

# Horizons -5 .. 17 (presumably quarters, given the 'quarter' grouping key).
lags = [*range(-5, 18)]
# Column-name suffixes: negative horizons are written 'm<k>' (e.g. -5 -> 'm5')
# because '-' cannot appear in a column/Stata variable name.
lag_str = [f'm{abs(lag)}' if lag < 0 else f'{lag}' for lag in lags]

# Markup series to analyze.
mu_list = [
	'mu_li_ms4',
	'mu_cfg1_ms4',
	'mu_cfg2_ms4',
]
# Grouping levels for the cross-sectional statistics.
groupby_vars_list = [
	['quarter'],
	['quarter', 'naics3'],
]

# Sum sales by (quarter, naics3) cell. 'const' is 1 on every row of compdat,
# so its sum is the row count of each cell, exported as 'n_firms'.
sales_list = (
	compdat.groupby(['quarter', 'naics3'])[['sale_ms4', 'saleq', 'const']]
	.sum()
	.reset_index()
	.rename(columns={'const': 'n_firms'})
)
sales_list.to_stata('data/data_temp/aggregated_sales_list.dta')

# Winsorize each horizon's sales-growth column at its 5th/95th percentiles and
# store the result next to the source column with a '_w1' suffix.
for lag_suffix in lag_str:
	sales_col = 'dlog_sale_ms4_win_' + lag_suffix
	compdat[sales_col + '_w1'] = clip_series(compdat[sales_col], 0.05, 0.95)

# Suffixes appended to each markup column name; only the unsuffixed series is
# used at present.
mu_suffixes = ['']

# For every grouping level and markup series, compute per-group sales-weighted
# covariances of -1/mu with (a) winsorized markup growth and (b) winsorized
# sales growth at every horizon in lag_str. Each combination is written to its
# own .dta file.
for groupby_vars in groupby_vars_list:
	for mu_var in mu_list:
		for mu_suffix in mu_suffixes:
			dlogmu_cov_list = []
			sales_cov_list = []
			# -1/mu depends only on the markup column, not on the horizon, so it
			# is computed once here rather than once per lag.
			compdat['neg_inv_mu'] = -1/compdat[mu_var + mu_suffix]
			for l in lag_str:
				print(l)
				dlogmu_col = 'dlog_' + mu_var + '_' + l
				dlogmu_w1_col = dlogmu_col + '_w1'
				sales_w1_col = 'dlog_sale_ms4_win_' + l + '_w1'
				compdat[dlogmu_w1_col] = clip_series(compdat[dlogmu_col], 0.05, 0.95)
				grouped = compdat.groupby(groupby_vars)
				# The lambdas are consumed immediately by apply(), so capturing the
				# loop-local column names is safe.
				dlogmu_cov_list.append(grouped.apply(lambda x: groupby_cov_fn(x, 'neg_inv_mu', dlogmu_w1_col, 'sale_ms4')))
				sales_cov_list.append(grouped.apply(lambda x: groupby_cov_fn(x, 'neg_inv_mu', sales_w1_col, 'sale_ms4')))
			cov_dlogmu = pd.concat(dlogmu_cov_list, axis=1)
			cov_sales = pd.concat(sales_cov_list, axis=1)
			cov_dlogmu.columns = ['dlogmu_cov_' + mu_var + '_' + l for l in lag_str]
			cov_sales.columns = ['sales_cov_' + mu_var + '_' + l for l in lag_str]
			all_measures = pd.concat([cov_dlogmu, cov_sales], axis=1)
			all_measures.to_stata('data/data_temp/aggregated_' + mu_var + mu_suffix + '_' + ''.join(groupby_vars) + '.dta')
